

from argparse import ArgumentParser# 用于解析命令行参数
from threading import Thread# 用于创建线程以异步处理文本生成

import gradio as gr# 用于创建交互式Web界面
import torch
# Pretrained causal LM, matching tokenizer, and streaming text iterator
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer


DEFAULT_CKPT_PATH = "../Qwen2-7B-Instruct"  # default model path; override with -c/--checkpoint-path

# Parse command-line arguments
def _get_args():
    """Define and parse the command-line options for the web chat demo.

    Returns:
        argparse.Namespace with checkpoint_path, cpu_only, share,
        inbrowser, server_port, and server_name attributes.
    """
    parser = ArgumentParser(description="Qwen2-Instruct web chat demo.")
    parser.add_argument(
        "-c",
        "--checkpoint-path",
        type=str,
        default=DEFAULT_CKPT_PATH,
        help="Checkpoint name or path, default to %(default)r",
    )
    parser.add_argument(
        "--cpu-only", action="store_true", help="Run demo with CPU only"
    )
    parser.add_argument(
        "--share",
        action="store_true",
        default=False,
        help="Create a publicly shareable link for the interface.",
    )
    parser.add_argument(
        "--inbrowser",
        action="store_true",
        default=False,
        help="Automatically launch the interface in a new tab on the default browser.",
    )
    parser.add_argument(
        "--server-port", type=int, default=8000, help="Demo server port."
    )
    parser.add_argument(
        "--server-name", type=str, default="127.0.0.1", help="Demo server name."
    )
    return parser.parse_args()

# Load the model and tokenizer
def _load_model_tokenizer(args):
    """Load the tokenizer and causal LM from ``args.checkpoint_path``.

    Args:
        args: Parsed CLI arguments; reads ``checkpoint_path`` and ``cpu_only``.

    Returns:
        Tuple ``(model, tokenizer)`` with the model in eval mode and its
        generation config capped at 2048 new tokens for chat turns.
    """
    tokenizer = AutoTokenizer.from_pretrained(
        args.checkpoint_path,
        resume_download=False,
    )

    # Honor --cpu-only; otherwise let HF Accelerate place the model ("auto").
    # BUGFIX: a stray unconditional `device_map = "cpu"` debug leftover used to
    # overwrite this choice, forcing CPU execution regardless of the flag.
    if args.cpu_only:
        device_map = "cpu"
    else:
        device_map = "auto"

    model = AutoModelForCausalLM.from_pretrained(
        args.checkpoint_path,
        torch_dtype="auto",
        device_map=device_map,
        resume_download=False,
    ).eval()
    model.generation_config.max_new_tokens = 2048  # For chat.

    return model, tokenizer

# Stream chat text generation from a background thread
def _chat_stream(model, tokenizer, query, history):
    """Yield response fragments for *query*, given prior (user, reply) turns.

    Generation runs on a daemon-less worker thread via ``model.generate``;
    this generator drains the streamer and yields each decoded piece.
    """
    # Flatten (user, assistant) history pairs into chat-template messages.
    conversation = [
        {"role": role, "content": text}
        for user_turn, assistant_turn in history
        for role, text in (("user", user_turn), ("assistant", assistant_turn))
    ]
    conversation.append({"role": "user", "content": query})

    prompt = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        tokenize=False,
    )
    model_inputs = tokenizer([prompt], return_tensors="pt").to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer=tokenizer, skip_prompt=True, timeout=60.0, skip_special_tokens=True
    )
    # Run generation concurrently so we can stream tokens as they arrive.
    Thread(
        target=model.generate,
        kwargs={**model_inputs, "streamer": streamer},
    ).start()

    yield from streamer

# Free caches (Python GC and, if present, CUDA memory)
def _gc():
    import gc

    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# Launch the demo web interface
def _launch_demo(args, model, tokenizer):
    """Build the Gradio Blocks chat UI around *model*/*tokenizer* and serve it.

    Args:
        args: Parsed CLI arguments; supplies share/inbrowser/port/name launch options.
        model: Loaded causal LM used for generation.
        tokenizer: Tokenizer matching *model*.
    """
    def predict(_query, _chatbot, _task_history):  # handle user input and stream the response
        print(f"User: {_query}")
        _chatbot.append((_query, ""))  # show the user turn with an empty reply slot
        full_response = ""  # final text to record into history
        response = ""  # text accumulated so far
        for new_text in _chat_stream(model, tokenizer, _query, history=_task_history):  # stream generated text
            response += new_text  # append the newly streamed fragment
            _chatbot[-1] = (_query, response)  # refresh the in-progress reply

            yield _chatbot  # push the partial transcript to the UI
            full_response = response  # remember the latest complete text

        print(f"History: {_task_history}")
        _task_history.append((_query, full_response))  # persist the finished turn
        print(f"Qwen: {full_response}")
    # Retry handler: drop the last turn and regenerate a response for it.
    def regenerate(_chatbot, _task_history):
        if not _task_history:
            yield _chatbot
            return
        item = _task_history.pop(-1)  # most recent (query, response) pair
        _chatbot.pop(-1)  # remove it from the visible transcript too
        yield from predict(item[0], _chatbot, _task_history)  # regenerate the answer
    # Clear the user input textbox.
    def reset_user_input():
        return gr.update(value="")  # empty the input field
    # Reset handler: clear the history, the transcript, and cached memory.
    def reset_state(_chatbot, _task_history):
        _task_history.clear()  # drop the conversation history
        _chatbot.clear()  # drop the visible transcript
        _gc()
        return _chatbot
    # Assemble the web interface with Gradio.
    with gr.Blocks() as demo:
        gr.Markdown("""\
<p align="center"><img src="https://qianwen-res.oss-accelerate-overseas.aliyuncs.com/assets/logo/qwen2.5_logo.png" style="height: 120px"/><p>""")
        gr.Markdown(
            """\
<center><font size=3>This WebUI is based on Qwen2.5-Instruct, developed by Alibaba Cloud. \
(本WebUI基于Qwen2.5-Instruct打造，实现聊天机器人功能。)</center>"""
        )
        gr.Markdown("""\
<center><font size=4>
Qwen2.5-7B-Instruct <a href="https://modelscope.cn/models/qwen/Qwen2.5-7B-Instruct/summary">🤖 </a> | 
<a href="https://huggingface.co/Qwen/Qwen2.5-7B-Instruct">🤗</a>&nbsp ｜ 
Qwen2.5-32B-Instruct <a href="https://modelscope.cn/models/qwen/Qwen2.5-32B-Instruct/summary">🤖 </a> | 
<a href="https://huggingface.co/Qwen/Qwen2.5-32B-Instruct">🤗</a>&nbsp ｜ 
Qwen2.5-72B-Instruct <a href="https://modelscope.cn/models/qwen/Qwen2.5-72B-Instruct/summary">🤖 </a> | 
<a href="https://huggingface.co/Qwen/Qwen2.5-72B-Instruct">🤗</a>&nbsp ｜ 
&nbsp<a href="https://github.com/QwenLM/Qwen2.5">Github</a></center>""")

        chatbot = gr.Chatbot(label="Qwen", elem_classes="control-height")  # chat transcript component
        query = gr.Textbox(lines=2, label="Input")  # user input textbox
        task_history = gr.State([])  # per-session conversation history

        with gr.Row():
            empty_btn = gr.Button("🧹 Clear History (清除历史)")
            submit_btn = gr.Button("🚀 Submit (发送)")
            regen_btn = gr.Button("🤔️ Regenerate (重试)")

        submit_btn.click(
            predict, [query, chatbot, task_history], [chatbot], show_progress=True
        )
        submit_btn.click(reset_user_input, [], [query])
        empty_btn.click(
            reset_state, [chatbot, task_history], outputs=[chatbot], show_progress=True
        )
        regen_btn.click(
            regenerate, [chatbot, task_history], [chatbot], show_progress=True
        )

        gr.Markdown("""\
<font size=2>Note: This demo is governed by the original license of Qwen2.5. \
We strongly advise users not to knowingly generate or allow others to knowingly generate harmful content, \
including hate speech, violence, pornography, deception, etc. \
(注：本演示受Qwen2.5的许可协议限制。我们强烈建议，用户不应传播及不应允许他人传播以下内容，\
包括但不限于仇恨言论、暴力、色情、欺诈相关的有害信息。)""")

    demo.queue().launch(
        share=args.share,
        inbrowser=args.inbrowser,
        server_port=args.server_port,
        server_name=args.server_name,
    )


def main():
    """Entry point: parse CLI options, load the model, and serve the demo."""
    cli_args = _get_args()
    model, tokenizer = _load_model_tokenizer(cli_args)
    _launch_demo(cli_args, model, tokenizer)


if __name__ == "__main__":
    main()
